{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E31Qa-NiFYsY"
      },
      "source": [
        "Copyright 2023 Google LLC.\n",
        "\n",
        "SPDX-License-Identifier: Apache-2.0"
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EHVEDjc_1CwU"
      },
      "source": [
        "# NOTE: This Colab is only for documentation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EIFvVoclwTTN"
      },
      "source": [
        "## **Learning Hybrid Continuous Time Policies**\n",
        "\n",
        "## Definition of [Hybrid Continuous Time (HCT) Policies](https://arxiv.org/abs/2203.08715)\n",
        "\n",
        "Suppose that we observe image(s) $s_t \\in \\mathbb{R}^{H \\times W \\times C}$ at\n",
        "every discrete time index $t \\in \\mathbb{N}$, and in-between time indices\n",
        "$t$ and $t + 1$, we have access to continuous-time (this will be\n",
        "straightforwardly relaxed to “higher-frequency”) state observations from other sensing\n",
        "modalities, denoted by the function $x_t(\\cdot) : \\tau \\in [0,T] \\rightarrow\n",
        "\\mathbb{R}^n$. The variable $\\tau$ is referred to as the \"interpolation\"\n",
        "time, indexing the continuous time between image observations $s_t$ and\n",
        "$s_{t+1}$. Thus, we choose an arbitrary interval $[0, T]$ to represent its\n",
        "domain, with $\\tau = 0$ coinciding with discrete time index $t$, and\n",
        "$\\tau = T$ coinciding with discrete time index $t+1$.\n",
        "\n",
        "An HCT policy is defined as a functional map from the observation tuple $o_t :=\n",
        "(s_t,x_t(\\cdot))$ to a control function $u_t(\\cdot): \\tau \\in [0,T]\n",
        "\\rightarrow U$, where $U$ is the control space. For ease of notation, assume\n",
        "that $U = \\mathbb{R}^m$. Therefore, in “MDP-notation”, our action \"$a_t$\",\n",
        "mapped from $o_t$, is the control function $u_t(\\cdot)$, drawing natural\n",
        "analogies with hierarchical policies.\n",
        "\n",
        "\u003e An HCT Policy therefore is a *functional* map from $o_t$ to $u_t$. Since\n",
        "\u003e the policy must be realizable with incoming observations, this map must\n",
        "\u003e additionally be causal."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yLShkwhQ0mSN"
      },
      "source": [
        "## Working with Discrete-Time Measurements\n",
        "\n",
        "While the functional representation will be useful in designing the\n",
        "architectures, in actuality we receive observations and generate actions at\n",
        "fixed frequencies. Thus, we introduce some additional notation.\n",
        "\n",
        "Assume that over the interval $\\tau \\in [0, T)$, we output $M \u003e 1$ equally\n",
        "spaced actions at $\\tau_0 = 0, \\tau_1 = \\frac{T}{M}, \\ldots, \\tau_{M-1} =\n",
        "\\frac{(M-1)T}{M}$. Recall that at $\\tau = T$, we reset $\\tau = 0$,\n",
        "coinciding with the arrival of the next image observation $s_{t+1}$. Thus, the control *frequency* is $M$ times the image-observation frequency. Let\n",
        "$\\mathbf{u}_t$ denote the set of actions $\\{u_t(\\tau_0), \\ldots,\n",
        "u_t(\\tau_{M-1})\\}$.\n",
        "\n",
        "Similarly, we assume that the frequency of the observations of the signal $x_t(\\cdot)$ is $N$ times the control frequency, where $N \\geq 1$. To notate this, we define $\\mathbf{x}_t^i$ to be the set of $N+1$ equally spaced observations of $x_t(\\cdot)$ within the interval $[\\tau_{i-1}, \\tau_i]$, for $i = 1,\\ldots, M-1$. That is, $\\mathbf{x}_t^{i} =\n",
        "\\{x_t(\\tau_{i-1}), x_t(\\tau_{i-1}+\\frac{T}{MN}),\\ldots, x_t(\\tau_{i}) \\}$.\n",
        "\n",
        "See the figure below for an illustration of these variables. Note that for ease of batching data, we additionally define $\\mathbf{x}_t^0$ as the set of $N+1$ high-frequency state observations in-between $u_{t-1}(\\tau_{M-1})$ and $u_t(\\tau_0)$.\n",
        "\n",
        "![HybridContinuousTimeStructure](InFuser.png)\n",
        "\n",
        "By the causal assumption, our architecture must conform to the following\n",
        "functional relations:\n",
        "\n",
        "$$\n",
        "\\begin{eqnarray}\n",
        "(s_t, \\mathbf{x}_t^0) \u0026\\rightarrow \u0026u_t(0) \\\\\n",
        "(s_t, \\mathbf{x}_t^{0}, \\ldots, \\mathbf{x}_t^{j}) \u0026\\rightarrow \u0026u_t(\\tau_j), \\text{ for } j = 1, \\ldots, M-1\n",
        "\\end{eqnarray}\n",
        "$$"
      ]
    },
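    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sampling grids above can be sketched in plain Python. This is an\n",
        "illustrative, dependency-free sketch and not code from this library: the helper\n",
        "names `action_times`, `observation_times`, and `causal_windows` are\n",
        "hypothetical, and times are measured relative to the arrival of $s_t$, so the\n",
        "window for $\\mathbf{x}_t^0$ covers negative times.\n",
        "\n",
        "```python\n",
        "def action_times(T, M):\n",
        "    # tau_i = i * T / M for i = 0, ..., M - 1.\n",
        "    return [i * T / M for i in range(M)]\n",
        "\n",
        "\n",
        "def observation_times(T, M, N, i):\n",
        "    # The N + 1 sample times of x_t in [tau_{i-1}, tau_i], spaced T / (M * N).\n",
        "    # i = 0 gives the window that ends at tau_0 (it starts before s_t arrives).\n",
        "    start = (i - 1) * T / M\n",
        "    return [start + k * T / (M * N) for k in range(N + 1)]\n",
        "\n",
        "\n",
        "def causal_windows(j):\n",
        "    # Indices i of the windows x_t^i that u_t(tau_j) may depend on.\n",
        "    return list(range(j + 1))\n",
        "\n",
        "\n",
        "# Example with assumed values T = 1, M = 4, N = 2.\n",
        "taus = action_times(1.0, 4)\n",
        "first_window = observation_times(1.0, 4, 2, 1)\n",
        "```\n",
        "\n",
        "Adjacent windows share an endpoint, which matches the definition of\n",
        "$\\mathbf{x}_t^i$ on the closed interval $[\\tau_{i-1}, \\tau_i]$."
      ]
    },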
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nWWTcSWcCmHu"
      },
      "source": [
        "# NDP"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ccY3iFwE1jU8"
      },
      "source": [
        "This is an adaptation of the\n",
        "[Neural Dynamic Policies](https://shikharbahl.github.io/neural-dynamic-policies/resources/ndp.pdf)\n",
        "paper. The key aspect of this model is that it uses only the observations\n",
        "$s_t$ and $x_t(0)$, and generates the control function $u_t(\\cdot)$ for\n",
        "the entire interval $\\tau \\in [0, T]$ open-loop by solving a parametric\n",
        "second-order ODE.\n",
        "\n",
        "NOTE: For this model, we set $T=1$. Thus, the $M$ actions are output at\n",
        "$\\tau \\in \\{0, \\frac{1}{M}, \\ldots, \\frac{M-1}{M}\\}$."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U1PYfn5-1ny7"
      },
      "source": [
        "A\n",
        "[Dynamic Movement Primitive](https://homes.cs.washington.edu/~todorov/courses/amath579/reading/DynamicPrimitives.pdf)\n",
        "(DMP) is defined by the following coupled parametric ODEs, second-order in\n",
        "$u_t$ and first-order in the phase $\\phi_t$:\n",
        "\n",
        "$$\n",
        "\\begin{eqnarray} \n",
        "\\dfrac{d^2 u_t(\\tau)}{d \\tau^2} \u0026:= \u0026\\alpha_u \\left(\\beta\n",
        "(g_t - u_t(\\tau)) - \\dfrac{d u_t(\\tau)}{d \\tau}\\right) + f_t(\\phi_t(\\tau)),\n",
        "\\quad \\tau \\in [0, 1)  \\\\\n",
        "\\dfrac{d \\phi_t(\\tau)}{d \\tau} \u0026:=\n",
        "\u0026-\\alpha_\\phi \\phi_t(\\tau), \\quad \\tau \\in [0, 1).\n",
        "\\end{eqnarray}\n",
        "$$\n",
        "\n",
        "where $\\alpha_u, \\alpha_\\phi, \\beta \\in \\mathbb{R}$ are\n",
        "positive constants, and $\\phi_t$ is the phase function with initial condition\n",
        "$\\phi_t(0) = 1$. The function $f_t$, known as the forcing function, takes\n",
        "the form:\n",
        "\n",
        "$$\n",
        "f_t(\\phi) = \\dfrac{\\phi}{\\sum_{k=1}^K \\psi_k(\\phi)} (W_t \\psi(\\phi)) \\circ (g_t - u_t(0))\n",
        "$$\n",
        "\n",
        "where $\\psi_k$ is a Gaussian radial basis function and $\\psi(\\phi) :=\n",
        "(\\psi_1(\\phi), \\ldots, \\psi_K(\\phi))$ where $K$ is the number of basis\n",
        "functions. The parameters defining this system of equations are the \"goal\n",
        "vector\" $g_t \\in \\mathbb{R}^m$ and the \"weights matrix\" $W_t \\in\n",
        "\\mathbb{R}^{m \\times K}$."
      ]
    },
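    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the DMP equations concrete, the following is a minimal sketch that\n",
        "integrates them with explicit Euler steps in plain Python. It is not the solver\n",
        "used by `ndp_model.py`: the function names, the RBF parameterization\n",
        "$\\psi_k(\\phi) = \\exp(-h_k (\\phi - c_k)^2)$ with centers $c_k$ and widths\n",
        "$h_k$, the default constants, and the step size are all assumptions chosen for\n",
        "illustration.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def rbf_features(phi, centers, widths):\n",
        "    # Gaussian radial basis functions psi_k(phi) (assumed parameterization).\n",
        "    return [math.exp(-h * (phi - c) ** 2) for c, h in zip(centers, widths)]\n",
        "\n",
        "\n",
        "def forcing(phi, W, g, u0, centers, widths):\n",
        "    # f(phi) = phi / sum_k psi_k(phi) * (W psi(phi)) o (g - u(0)).\n",
        "    psi = rbf_features(phi, centers, widths)\n",
        "    scale = phi / sum(psi)\n",
        "    return [scale * sum(w * p for w, p in zip(row, psi)) * (gi - u0i)\n",
        "            for row, gi, u0i in zip(W, g, u0)]\n",
        "\n",
        "\n",
        "def integrate_dmp(u0, du0, g, W, centers, widths,\n",
        "                  alpha_u=25.0, beta=6.25, alpha_phi=1.0, M=10, substeps=100):\n",
        "    # Explicit Euler on tau in [0, 1); records u at tau_i = i / M.\n",
        "    u, du, phi = list(u0), list(du0), 1.0\n",
        "    dt = 1.0 / (M * substeps)\n",
        "    actions = []\n",
        "    for _ in range(M):\n",
        "        actions.append(list(u))\n",
        "        for _ in range(substeps):\n",
        "            f = forcing(phi, W, g, u0, centers, widths)\n",
        "            ddu = [alpha_u * (beta * (gi - ui) - dui) + fi\n",
        "                   for gi, ui, dui, fi in zip(g, u, du, f)]\n",
        "            u = [ui + dt * dui for ui, dui in zip(u, du)]\n",
        "            du = [dui + dt * a for dui, a in zip(du, ddu)]\n",
        "            phi += dt * (-alpha_phi * phi)\n",
        "    return actions\n",
        "\n",
        "\n",
        "# Example with m = 2 and K = 5 basis functions (all values are assumptions).\n",
        "K = 5\n",
        "centers = [math.exp(-k / (K - 1)) for k in range(K)]\n",
        "widths = [10.0] * K\n",
        "W = [[0.0] * K for _ in range(2)]\n",
        "actions = integrate_dmp([0.0, 0.0], [0.0, 0.0], [1.0, -1.0], W, centers, widths)\n",
        "```\n",
        "\n",
        "With $W_t = 0$ the forcing term vanishes and, for the assumed choice\n",
        "$\\beta = \\alpha_u / 4$, the system is critically damped, so $u_t$ approaches\n",
        "$g_t$ without oscillating."
      ]
    },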
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "beO8MQWx2S9R"
      },
      "source": [
        "The architecture is defined by three modules:\n",
        "\n",
        "*   an encoder that maps the observations $(s_t, x_t(0))$ to an image embedding $z_{s_t}$ and DMP parameters $\\{g_t, W_t\\}$ (hence *neural dynamic* policies),\n",
        "*   a decoder that maps $(z_{s_t}, x_t(0))$ to the DMP ODE's initial conditions $\\left(u_t(0), \\frac{d u_t(0)}{d\\tau}\\right)$, and\n",
        "*   the DMP ODEs defined above. The constants $\\{\\alpha_u, \\alpha_\\phi, \\beta, K\\}$ are left as hyper-parameters."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Hbn4HEP2Y_Q"
      },
      "source": [
        "## [Flax](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html) Module\n",
        "\n",
        "To define an NDP model, we initialize an `NDP` module object defined in\n",
        "`ndp_model.py`. Please see the module header for all relevant definitions. We\n",
        "outline certain key forward-pass methods."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oznle8Sc2rKI"
      },
      "source": [
        "## Computing the Flow\n",
        "\n",
        "There are three ways of computing the \"flow,\" i.e., the solution of the NDP to\n",
        "obtain the control actions between time-steps $t$ and $t+1$.\n",
        "\n",
        "1.  (Batched) Compute all $M$ actions $\\mathbf{u}_t$ from $(s_t, x_t(0))$\n",
        "    using:\n",
        "\n",
        "    ```python\n",
        "    model.apply(params, batch_images, batch_hf_obs)\n",
        "    ```\n",
        "\n",
        "    where\n",
        "\n",
        "    *   `params` are the module parameters,\n",
        "    *   `batch_images` is a batch of image tensors ($s_t$), and\n",
        "    *   `batch_hf_obs` is a batch of high-frequency observations ($x_t(0)$).\n",
        "\n",
        "    Each instance in the batched output has shape $M \\times m$.\n",
        "\n",
        "2.  (Batched) Compute the solution $u_t(\\cdot)$ at a dense set of times $\\tau\n",
        "    \\in [0, 1]$, given by the vector `pred_times` using:\n",
        "\n",
        "    ```python\n",
        "    model.apply(params, batch_images, batch_hf_obs, pred_times,\n",
        "                method=ndp_model.compute_ndp_flow)\n",
        "    ```\n",
        "\n",
        "    Each instance in the batched output has shape `len(pred_times)`$\\times m$.\n",
        "\n",
        "3.  (Unbatched) Compute each action in the sequence iteratively, i.e., as a\n",
        "    policy.\n",
        "\n",
        "    ```python\n",
        "    # Extract the step functions for the NDP model\n",
        "    re_init, step_fwd = model.step_functions\n",
        "\n",
        "    # Given a new (image, hf_obs) pair, compute the initial action and NDP params\n",
        "    ndp_state, ndp_args = re_init(params, image, hf_obs)\n",
        "    # u(0) = ndp_state[:model.action_dim]\n",
        "\n",
        "    # Compute u(tau_1)...u(tau_{M-1}) step-by-step:\n",
        "    tau = 0.\n",
        "    for i in range(1, M):\n",
        "      ndp_state, tau = step_fwd(params, ndp_state, tau, ndp_args)\n",
        "      # u(tau_i) = ndp_state[:model.action_dim]\n",
        "    ```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rk8TrU-I24rm"
      },
      "source": [
        "## Training via Imitation Learning\n",
        "\n",
        "Let $\\hat{\\mathbf{u}}_t = \\{\\hat{u}_t(0), \\ldots, \\hat{u}_t(\\frac{M-1}{M})\\}$\n",
        "be the observed sequence of actions between discrete time-steps $t$ and\n",
        "$t+1$. Let us define its linear interpolant as the function\n",
        "$\\hat{u}_t(\\cdot)$, and the control function generated by the NDP model\n",
        "(conditioned on observations) as $u_{t,\\theta}(\\cdot)$, where $\\theta$\n",
        "represents all learnable parameters. We define the imitation loss as the\n",
        "integral:\n",
        "\n",
        "$$\n",
        "I_t(\\theta) := \\int_{0}^{\\frac{M-1}{M}} l(\\hat{u}_t(\\tau), u_{t,\\theta}(\\tau))\\ d\\tau,\n",
        "$$\n",
        "\n",
        "where $l : \\mathbb{R}^m \\times \\mathbb{R}^m \\rightarrow \\mathbb{R}$ is some\n",
        "loss function penalizing the discrepancy between the observed and predicted\n",
        "actions. Note that this integral is equal to the terminal value of the\n",
        "following auxiliary ODE:\n",
        "\n",
        "$$\n",
        "  \\dfrac{d J(\\tau)}{d\\tau} = l(\\hat{u}_t(\\tau), u_{t,\\theta}(\\tau)), \\quad \\tau \\in [0, \\frac{M-1}{M}],\\ J(0) = 0.\n",
        "$$\n",
        "\n",
        "To compute the loss, we solve an augmented set of ODEs which includes the\n",
        "NDP-ODE as well as the auxiliary \"cost-ODE\" defined above.\n",
        "\n",
        "Given a batched set of observations `batch_images`, `batch_hf_obs` and true\n",
        "(observed) actions `batch_true_actions`, where each instance in\n",
        "`batch_true_actions` has shape $M \\times m$, we can compute the vector of\n",
        "losses across the batch as:\n",
        "\n",
        "```python\n",
        "batch_pred_actions, batch_losses = model.apply(\n",
        "  params, batch_images, batch_hf_obs, batch_true_actions,\n",
        "  method=ndp_model.compute_augmented_flow)\n",
        "```\n",
        "\n",
        "Here `batch_pred_actions` is the batch of predicted actions, and `batch_losses`\n",
        "is the vector of losses across the batch. The losses can be averaged to yield a\n",
        "final scalar loss, which can be auto-diffed and incorporated into any training\n",
        "pipeline. This library uses\n",
        "[`flax`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState)\n",
        "along with [`optax`](https://github.com/deepmind/optax); see `ndp_utils.py`."
      ]
    },
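    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The equivalence between the integral loss and the terminal value of the\n",
        "cost-ODE can be checked numerically. The sketch below is an illustration under\n",
        "stated assumptions, not the library's `compute_augmented_flow`: it treats scalar\n",
        "actions ($m = 1$), takes the predicted control as an arbitrary callable instead\n",
        "of solving the NDP-ODE jointly, and integrates $dJ/d\\tau$ with a fixed-step\n",
        "midpoint rule. The names `linear_interpolant` and `imitation_cost` are\n",
        "hypothetical.\n",
        "\n",
        "```python\n",
        "def linear_interpolant(samples, M):\n",
        "    # samples[i] is the demonstrated scalar action at tau_i = i / M (M at least 2).\n",
        "    def u_hat(tau):\n",
        "        pos = min(max(tau * M, 0.0), M - 1.0)\n",
        "        i = min(int(pos), M - 2)\n",
        "        w = pos - i\n",
        "        return (1.0 - w) * samples[i] + w * samples[i + 1]\n",
        "    return u_hat\n",
        "\n",
        "\n",
        "def imitation_cost(u_hat, u_pred, M, loss, steps=3000):\n",
        "    # Integrates dJ/dtau = loss(u_hat(tau), u_pred(tau)) with J(0) = 0\n",
        "    # over [0, (M - 1) / M] and returns the terminal value of J.\n",
        "    end = (M - 1) / M\n",
        "    dt = end / steps\n",
        "    J = 0.0\n",
        "    for k in range(steps):\n",
        "        tau = (k + 0.5) * dt  # midpoint of the k-th sub-interval\n",
        "        J += dt * loss(u_hat(tau), u_pred(tau))\n",
        "    return J\n",
        "\n",
        "\n",
        "# Example: M = 3 demonstrated actions, a zero prediction, squared-error loss.\n",
        "u_hat = linear_interpolant([0.0, 1.0, 2.0], 3)\n",
        "J = imitation_cost(u_hat, lambda tau: 0.0, 3, lambda a, b: (a - b) ** 2)\n",
        "```\n",
        "\n",
        "For the demonstrated samples $(0, 1, 2)$ with $M = 3$ and a zero prediction, the\n",
        "interpolant is $\\hat{u}_t(\\tau) = 3\\tau$, so the squared-error loss integrates\n",
        "to $\\int_0^{2/3} 9\\tau^2\\, d\\tau = 8/9$, which the terminal value of `J`\n",
        "approximates."
      ]
    },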
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vqbay_Fk1oFb"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "16L1qsmG0B-HQWZYn_A-iIXDP6iz2YKCG",
          "timestamp": 1680100608913
        },
        {
          "file_id": "1cOK4xh3_IshDQz6GjMDG3prg7b9Xdb6G",
          "timestamp": 1679955209952
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
